/*
 * Decompiled with CFR 0.152.
 */
package org.languagetool.rules.spelling.suggestions;

import club.sk1er.org.apache.commons.lang3.tuple.Pair;
import club.sk1er.org.apache.commons.pool2.BaseKeyedPooledObjectFactory;
import club.sk1er.org.apache.commons.pool2.KeyedObjectPool;
import club.sk1er.org.apache.commons.pool2.PooledObject;
import club.sk1er.org.apache.commons.pool2.impl.DefaultPooledObject;
import club.sk1er.org.apache.commons.pool2.impl.GenericKeyedObjectPool;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.jetbrains.annotations.NotNull;
import org.languagetool.AnalyzedSentence;
import org.languagetool.JLanguageTool;
import org.languagetool.Language;
import org.languagetool.languagemodel.LanguageModel;
import org.languagetool.rules.SuggestedReplacement;
import org.languagetool.rules.spelling.suggestions.SuggestionsOrdererFeatureExtractor;
import org.languagetool.rules.spelling.suggestions.SuggestionsRanker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Ranks spelling suggestions with a per-language XGBoost model.
 *
 * <p>Models are loaded from the resource directory (see {@link #getModelPath(Language)}) and
 * shared through a keyed object pool. A language is only considered supported when it has an
 * entry in both the auto-correct threshold table and the model class table, and its model
 * resource exists and can be loaded. On Windows the ranker disables itself for the whole JVM.
 */
public class XGBoostSuggestionsOrderer
extends SuggestionsOrdererFeatureExtractor
implements SuggestionsRanker {
    private static final Logger logger = LoggerFactory.getLogger(XGBoostSuggestionsOrderer.class);

    /** Pool of loaded models, keyed by language. */
    private static final KeyedObjectPool<Language, Booster> modelPool = new GenericKeyedObjectPool<Language, Booster>(new BaseKeyedPooledObjectFactory<Language, Booster>(){

        /**
         * Loads the model for the given language from the resource directory.
         *
         * @return the loaded model, or {@code null} if the model resource could not be found
         *         (the constructor of the enclosing class checks for {@code null})
         */
        @Override
        public Booster create(Language language) throws Exception {
            String modelPath = XGBoostSuggestionsOrderer.getModelPath(language);
            // try-with-resources: the stream must be closed once the model has been read
            try (InputStream savedModel = JLanguageTool.getDataBroker().getFromResourceDirAsStream(modelPath)) {
                return XGBoost.loadModel(savedModel);
            }
            catch (FileNotFoundException e) {
                logger.warn("Could not load suggestion ranking model at '{}'. Platform might be unsupported by the official XGBoost maven package, or model might be missing/corrupted.", modelPath, e);
                return null;
            }
        }

        @Override
        public PooledObject<Booster> wrap(Booster booster) {
            return new DefaultPooledObject<Booster>(booster);
        }
    });

    /** Minimum confidence of the top suggestion required for auto-correction, by language code. */
    private static final Map<String, Float> autoCorrectThreshold = new HashMap<String, Float>();
    /** Candidate index represented by each model output position, by language code. */
    private static final Map<String, List<Integer>> modelClasses = new HashMap<String, List<Integer>>();
    /** Expected number of features per candidate, by language code (used for sanity warnings only). */
    private static final Map<String, Integer> candidateFeatureCount = new HashMap<String, Integer>();
    /** Expected number of match-level features, by language code (used for sanity warnings only). */
    private static final Map<String, Integer> matchFeatureCount = new HashMap<String, Integer>();

    static {
        List<Integer> defaultClasses = Arrays.asList(-1, 0, 1, 2, 3, 4);
        autoCorrectThreshold.put("en-US", Float.valueOf(0.99897194f));
        modelClasses.put("en-US", defaultClasses);
        candidateFeatureCount.put("en-US", 10);
        matchFeatureCount.put("en-US", 1);
    }

    /** Set once the model for this instance's language has been borrowed successfully. */
    private boolean modelAvailableForLanguage = false;
    /** Set for all instances once XGBoost is found to be unusable on this platform. */
    private static boolean xgboostNotSupported = false;

    /**
     * Builds the resource path of the ranking model for a language.
     *
     * @param language language whose short code selects the model directory
     * @return path of the form {@code /<shortCode>/spelling_correction_model.bin}
     */
    @NotNull
    private static String getModelPath(Language language) {
        return "/" + language.getShortCode() + "/spelling_correction_model.bin";
    }

    /**
     * Overrides the auto-correct threshold of a language that already has one configured.
     * Languages without an existing entry are left untouched.
     *
     * @param lang  language whose threshold is changed
     * @param value new minimum confidence for auto-correction
     */
    public static void setAutoCorrectThresholdForLanguage(Language lang, float value) {
        autoCorrectThreshold.replace(lang.getShortCodeWithCountryAndVariant(), Float.valueOf(value));
    }

    /**
     * Creates a ranker and probes whether a model can be loaded for {@code lang}.
     *
     * @param lang          language to rank suggestions for
     * @param languageModel language model passed on to the feature extractor
     */
    public XGBoostSuggestionsOrderer(Language lang, LanguageModel languageModel) {
        super(lang, languageModel);
        if (xgboostNotSupported) {
            return;
        }
        if (System.getProperty("os.name").toLowerCase().startsWith("windows")) {
            xgboostNotSupported = true;
            System.err.println("Warning: At the moment, your platform (Windows) is not supported by the official XGBoost maven package; ML-based suggestion reordering is disabled.");
            return;
        }
        String langCode = lang.getShortCodeWithCountryAndVariant();
        boolean configured = autoCorrectThreshold.containsKey(langCode) && modelClasses.containsKey(langCode);
        if (!configured || !JLanguageTool.getDataBroker().resourceExists(XGBoostSuggestionsOrderer.getModelPath(this.language))) {
            return;
        }
        try {
            // Borrow and immediately return the model: this only verifies that it can be loaded.
            Booster model = modelPool.borrowObject(this.language);
            if (model != null) {
                modelPool.returnObject(this.language, model);
                this.modelAvailableForLanguage = true;
            }
        }
        catch (ExceptionInInitializerError | NoClassDefFoundError | UnsatisfiedLinkError e) {
            logger.warn("At the moment, your platform (Windows?) or architecture (32 bit?) is not supported by the official XGBoost maven package; ML-based suggestion reordering is disabled.", e);
            xgboostNotSupported = true;
        }
        catch (Exception e) {
            logger.warn("Could not load spelling suggestion ranking model for language " + this.language, e);
        }
    }

    @Override
    protected void initParameters() {
        this.topN = 5;
        this.score = "noop";
        this.mistakeProb = 0.0;
    }

    /**
     * @return {@code true} only if the superclass reports ML as available and a model was
     *         loaded for this instance's language
     */
    @Override
    public boolean isMlAvailable() {
        return super.isMlAvailable() && this.modelAvailableForLanguage;
    }

    /**
     * Scores the candidate suggestions with the XGBoost model and returns them sorted by
     * descending confidence.
     *
     * <p>If the model cannot be borrowed or applied, the error is logged and the candidates
     * are returned in their current order.
     *
     * @param suggestions raw suggestion strings
     * @param word        the misspelled word
     * @param sentence    sentence containing the word
     * @param startPos    start position of the word, passed on to feature extraction
     * @return ranked candidates, or an empty list if feature extraction produced none
     * @throws IllegalStateException if {@link #isMlAvailable()} is {@code false}
     * @throws RuntimeException      if the model cannot be returned to the pool
     */
    @Override
    public List<SuggestedReplacement> orderSuggestions(List<String> suggestions, String word, AnalyzedSentence sentence, int startPos) {
        if (!this.isMlAvailable()) {
            throw new IllegalStateException("Illegal call to orderSuggestions() - isMlAvailable() returned false.");
        }
        String langCode = this.language.getShortCodeWithCountryAndVariant();
        Pair<List<SuggestedReplacement>, SortedMap<String, Float>> candidatesAndFeatures = this.computeFeatures(suggestions, word, sentence, startPos);
        List<SuggestedReplacement> candidates = candidatesAndFeatures.getLeft();
        if (candidates.isEmpty()) {
            return Collections.emptyList();
        }
        SortedMap<String, Float> matchFeatures = candidatesAndFeatures.getRight();
        List<SortedMap<String, Float>> suggestionFeatures = candidates.stream().map(SuggestedReplacement::getFeatures).collect(Collectors.toList());

        float[] data = this.buildFeatureVector(matchFeatures, suggestionFeatures, langCode);
        List<Integer> labels = modelClasses.get(langCode);

        Booster model = null;
        DMatrix matrix = null;
        try {
            model = modelPool.borrowObject(this.language);
            matrix = new DMatrix(data, 1, data.length);
            float[][] output = model.predict(matrix);
            if (output.length != 1) {
                throw new XGBoostError(String.format("XGBoost returned array with first dimension of length %d, expected 1.", output.length));
            }
            float[] probabilities = output[0];
            if (probabilities.length != labels.size()) {
                throw new XGBoostError(String.format("XGBoost returned array with second dimension of length %d, expected %d.", probabilities.length, labels.size()));
            }
            for (int candidateIndex = 0; candidateIndex < candidates.size(); ++candidateIndex) {
                // Candidates whose index is not among the model's classes get confidence 0.
                int labelIndex = labels.indexOf(candidateIndex);
                float prob = labelIndex != -1 ? probabilities[labelIndex] : 0.0f;
                candidates.get(candidateIndex).setConfidence(Float.valueOf(prob));
            }
        }
        catch (XGBoostError xgBoostError) {
            logger.error("Error while applying XGBoost model to spelling suggestions", xgBoostError);
            return candidates;
        }
        catch (Exception e) {
            logger.error("Error while loading XGBoost model for spelling suggestions", e);
            return candidates;
        }
        finally {
            try {
                // Free the matrix explicitly instead of waiting for it to be garbage collected.
                if (matrix != null) {
                    matrix.dispose();
                }
            }
            finally {
                if (model != null) {
                    try {
                        modelPool.returnObject(this.language, model);
                    }
                    catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }
        candidates.sort(Collections.reverseOrder(Comparator.comparing(SuggestedReplacement::getConfidence)));
        return candidates;
    }

    /**
     * Flattens match-level and per-candidate features into one model input row: match
     * features first, then each candidate's features in candidate order. The row is sized for
     * {@code topN} candidates, so slots of missing candidates stay zero.
     *
     * <p>Feature counts that differ from the configured expectation are logged as warnings
     * but do not abort the computation.
     *
     * @param matchFeatures      features describing the match as a whole
     * @param suggestionFeatures features of each candidate; must not be empty
     * @param langCode           language code used to look up the expected feature counts
     * @return the flattened feature row
     */
    private float[] buildFeatureVector(SortedMap<String, Float> matchFeatures, List<SortedMap<String, Float>> suggestionFeatures, String langCode) {
        // NOTE(review): sized from the first candidate's feature count; more than topN candidates
        // (or larger feature maps on later candidates) would overrun the array - confirm that
        // computeFeatures() never yields that.
        int numFeatures = matchFeatures.size() + this.topN * suggestionFeatures.get(0).size();
        float[] data = new float[numFeatures];
        int featureIndex = 0;
        int expectedMatchFeatures = matchFeatureCount.getOrDefault(langCode, -1);
        int expectedCandidateFeatures = candidateFeatureCount.getOrDefault(langCode, -1);
        if (matchFeatures.size() != expectedMatchFeatures) {
            logger.warn("Match features '{}' do not have expected size {}.", matchFeatures, expectedMatchFeatures);
        }
        for (Float value : matchFeatures.values()) {
            data[featureIndex++] = value.floatValue();
        }
        for (SortedMap<String, Float> candidateFeatures : suggestionFeatures) {
            if (candidateFeatures.size() != expectedCandidateFeatures) {
                logger.warn("Candidate features '{}' do not have expected size {}.", candidateFeatures, expectedCandidateFeatures);
            }
            for (Float value : candidateFeatures.values()) {
                data[featureIndex++] = value.floatValue();
            }
        }
        return data;
    }

    /**
     * Decides whether the top-ranked suggestion is confident enough to be applied automatically.
     *
     * @param rankedSuggestions suggestions in ranked order, best first
     * @return {@code true} if every suggestion has a confidence and the first one reaches the
     *         language's threshold; {@code false} for an empty list or when no threshold is
     *         configured for the language
     */
    @Override
    public boolean shouldAutoCorrect(List<SuggestedReplacement> rankedSuggestions) {
        if (rankedSuggestions.isEmpty() || rankedSuggestions.stream().anyMatch(s -> s.getConfidence() == null)) {
            return false;
        }
        // Float.MAX_VALUE as default: languages without a threshold never auto-correct.
        float threshold = autoCorrectThreshold.getOrDefault(this.language.getShortCodeWithCountryAndVariant(), Float.valueOf(Float.MAX_VALUE)).floatValue();
        return rankedSuggestions.get(0).getConfidence().floatValue() >= threshold;
    }
}

